"""
WebSocket Authentication Middleware for Django Channels

This middleware provides authentication for WebSocket connections using:
1. Token authentication (for web clients)
2. Device key authentication (for ESP32 devices)
"""

import hmac
import logging
from urllib.parse import parse_qs

from channels.db import database_sync_to_async
from channels.middleware import BaseMiddleware
from django.contrib.auth.models import AnonymousUser
from django.core.exceptions import ObjectDoesNotExist
from rest_framework.authtoken.models import Token

logger = logging.getLogger(__name__)


class TokenAuthMiddleware(BaseMiddleware):
    """
    Custom middleware for WebSocket token authentication.

    Supports two authentication methods:
    1. Token authentication: ?token=<user_token>
    2. Device authentication: ?device_key=<device_secret>, with the device ID
       taken from a ``/ws/device/<device_id>/`` URL path.

    For WebSocket connections the inner application receives a *copy* of the
    scope with these keys set:

    - ``user``: the token's (active) user, or ``AnonymousUser`` otherwise.
    - ``is_device`` / ``device_authenticated``: ``True`` only when the device
      key matched, ``False`` otherwise.
    - ``device_id``: present only when device authentication succeeded.

    Non-WebSocket scopes are passed through untouched.
    """

    def __init__(self, inner):
        super().__init__(inner)

    async def __call__(self, scope, receive, send):
        """
        Authenticate a WebSocket connection and hand an annotated scope copy
        to the inner application.
        """
        # Only process WebSocket connections
        if scope['type'] != 'websocket':
            return await self.inner(scope, receive, send)

        # BaseMiddleware.__call__ copies the scope so changes do not leak
        # upstream; since we override it, do the same here.
        scope = dict(scope)

        # The query string is untrusted input: replace undecodable bytes
        # instead of letting a UnicodeDecodeError abort the handshake.
        query_string = scope.get('query_string', b'').decode('utf-8', errors='replace')
        query_params = parse_qs(query_string)

        # Initialize user and device flags
        scope['user'] = AnonymousUser()
        scope['is_device'] = False
        scope['device_authenticated'] = False

        # Try token authentication first (for web clients)
        token = query_params.get('token', [None])[0]
        if token:
            user = await self.get_user_from_token(token)
            if user:
                scope['user'] = user
                logger.info("WebSocket: User %s authenticated via token", user.get_username())
            else:
                # Never log the token value itself: it is a bearer credential.
                logger.warning("WebSocket: Invalid token provided")

        # Try device key authentication (for ESP32 devices)
        device_key = query_params.get('device_key', [None])[0]
        device_id = self.extract_device_id_from_scope(scope)

        if device_key and device_id:
            if await self.authenticate_device(device_id, device_key):
                scope['is_device'] = True
                scope['device_authenticated'] = True
                scope['device_id'] = device_id
                logger.info("WebSocket: Device %s authenticated via device key", device_id)
            else:
                logger.warning("WebSocket: Invalid device key for device %s", device_id)

        return await self.inner(scope, receive, send)

    @database_sync_to_async
    def get_user_from_token(self, token_key):
        """
        Return the active user owning ``token_key``, or ``None``.

        Inactive users are rejected, matching DRF's ``TokenAuthentication``.
        The user row is fetched in the same query via ``select_related``.
        """
        try:
            token = Token.objects.select_related('user').get(key=token_key)
        except ObjectDoesNotExist:
            return None
        user = token.user
        if not user.is_active:
            return None
        return user

    async def authenticate_device(self, device_id, device_key):
        """
        Authenticate a device using its device key.

        WARNING: the expected key is derived purely from ``device_id``
        (``device_<id>_secret_key``), so anyone who knows the ID can forge it.
        This is demonstration logic only. In production you should:
        1. Store device keys in a database with proper encryption
        2. Implement key rotation mechanisms
        3. Add device registration/deregistration workflows

        Returns ``True`` if ``device_key`` matches, ``False`` otherwise.
        """
        if not isinstance(device_key, str):
            return False

        expected_key = f"device_{device_id}_secret_key"

        # You could also implement more sophisticated authentication:
        # - JWT tokens for devices
        # - Certificate-based authentication
        # - Time-based one-time passwords (TOTP)

        # Constant-time comparison avoids leaking key prefixes via timing.
        # Compare bytes because compare_digest rejects non-ASCII str.
        return hmac.compare_digest(
            device_key.encode('utf-8'), expected_key.encode('utf-8')
        )

    def extract_device_id_from_scope(self, scope):
        """
        Extract the device ID from a ``/ws/device/<device_id>/`` URL path.

        Returns the third path segment, or ``None`` if the path does not
        match that pattern.
        """
        try:
            path = scope.get('path', '')
            if '/ws/device/' in path:
                parts = path.strip('/').split('/')
                if len(parts) >= 3 and parts[0] == 'ws' and parts[1] == 'device':
                    return parts[2]
        except Exception as e:
            # e.g. a non-string 'path' value; treat as "no device ID".
            logger.error("Error extracting device ID from scope: %s", e)
        return None


class DeviceKeyAuthMiddleware:
    """
    Alternative middleware specifically for device authentication.
    This can be used if you want separate authentication logic for devices.

    Credentials are read from the query string: ``?device_id=<id>&device_key=<key>``.
    The inner application receives a copy of the scope. In that copy,
    ``device_authenticated`` is ``True`` on success and defaults to ``False``
    otherwise; any value an outer middleware already set is kept. On success,
    ``device_id`` is also set.
    """

    # Demonstration credential table (device ID -> secret). It is defined once
    # at class level instead of being rebuilt on every connection. Replace it
    # with a database lookup in production.
    VALID_DEVICES = {
        'esp32_001': 'device_esp32_001_secret_key',
        'esp32_002': 'device_esp32_002_secret_key',
        'door_controller_01': 'device_door_controller_01_secret_key',
    }

    def __init__(self, inner):
        self.inner = inner

    async def __call__(self, scope, receive, send):
        """
        Device-specific authentication middleware.
        """
        if scope['type'] != 'websocket':
            return await self.inner(scope, receive, send)

        # Copy so changes never leak back into the caller's scope.
        scope = dict(scope)
        # Always expose the flag so consumers can index it safely.
        scope.setdefault('device_authenticated', False)

        # Extract device information (untrusted input: don't crash on bad UTF-8)
        query_string = scope.get('query_string', b'').decode('utf-8', errors='replace')
        query_params = parse_qs(query_string)

        device_key = query_params.get('device_key', [None])[0]
        device_id = query_params.get('device_id', [None])[0]

        # Validate device credentials
        if device_key and device_id:
            if await self.validate_device_credentials(device_id, device_key):
                scope['device_authenticated'] = True
                scope['device_id'] = device_id
                logger.info("Device %s authenticated successfully", device_id)
            else:
                logger.warning("Device %s authentication failed", device_id)
                # You might want to close the connection here

        return await self.inner(scope, receive, send)

    async def validate_device_credentials(self, device_id, device_key):
        """
        Return ``True`` if ``device_key`` is the registered secret for
        ``device_id``, ``False`` otherwise (including unknown devices).

        Production implementations could instead use:
        - Database lookup for registered devices
        - Certificate validation
        - API key validation
        """
        expected_key = self.VALID_DEVICES.get(device_id)
        if expected_key is None or not isinstance(device_key, str):
            return False
        # Constant-time comparison; bytes so non-ASCII input cannot raise.
        return hmac.compare_digest(
            expected_key.encode('utf-8'), device_key.encode('utf-8')
        )


def TokenAuthMiddlewareStack(inner):
    """
    Wrap ``inner`` in a ``TokenAuthMiddleware`` and return the wrapper.

    Intended for the ASGI routing configuration. Only the token/device-key
    checks are applied; no cookie or session middleware is added.
    """
    wrapped_app = TokenAuthMiddleware(inner)
    return wrapped_app


def DeviceAuthMiddlewareStack(inner):
    """
    Wrap ``inner`` in a ``DeviceKeyAuthMiddleware`` (device-only auth).

    Returns the wrapping middleware instance.
    """
    wrapped_app = DeviceKeyAuthMiddleware(inner)
    return wrapped_app
